from . import adb

import frida
import hashlib
import json
import threading
import glob
import os
import queue

def md5(bs):
    """Return the hexadecimal MD5 digest of the bytes-like object *bs*."""
    return hashlib.md5(bs).hexdigest()

from loguru import logger

class Message:
    """Common base class for the message records built by the Frida message handler."""

class MessageEvent(Message):
    """Record of one hooked-method event.

    Attributes:
        index: sequence number supplied by the producer.
        from_method: the value received as ``function_call``.
        to_method: the value received as ``function_callee``.
        params: parameter data supplied by the producer (presumably a dict of
            payload fields - confirm against callers).
        ret_value: the return value as supplied by the producer.
    """

    def __init__(self, index, function_callee, function_call, params, ret_value):
        # NOTE(review): the signature lists the callee before the call, while the
        # Frida handler passes (call, callee) positionally, so from_method ends up
        # holding the handler's "callee". Kept unchanged; confirm which is intended.
        self.index = index
        self.from_method, self.to_method = function_call, function_callee
        self.params = params
        self.ret_value = ret_value

class MessageSystem(Message):
    """Record of one system-information event.

    It exposes the same attribute names as MessageEvent; the ``information``
    argument is stored under ``ret_value`` (presumably so both record types can
    be consumed uniformly - confirm with consumers of the message queue).
    """

    def __init__(self, index, function_callee, function_call, params, information):
        # NOTE(review): like MessageEvent, callee precedes call in the signature
        # while from_method is taken from function_call.
        self.index = index
        self.from_method, self.to_method = function_call, function_callee
        self.params = params
        self.ret_value = information

class Pentest:
    """Instrument an Android app with Frida and feed its hook messages to androguard.

    start_trace() is the usual entry point. It:

    1. resolves the target package, installing the APK unless ``live`` is set;
    2. attaches to the package with spawn gating enabled;
    3. injects ``ag_scripts`` plus the requested modules;
    4. either resumes the process or dumps in-memory DEX files.

    Messages posted by the injected JavaScript are decoded by
    androguard_message_handler(), pushed onto ``message_queue`` and recorded
    through ``ag_session``.
    """

    def __init__(self):
        self.ag_session = None
        self.package_name = None
        self.device = None
        self.frida_session = None
        self.pid = -1
        self.detached = False
        self.scripts = []
        self.pending = []
        self.list_file_scripts = []
        self.ag_scripts = ['androguard/pentest/internal/utils.js']
        self.idx = 0
        self.message_queue = queue.Queue()
        # Exists from construction (and is recreated per attach) so the device
        # signal handlers can always call self.event.set().
        self.event = threading.Event()
        # Stack trace of the last AG-EVENT message. Following AG-SYSTEM messages
        # are attributed to it. It must live on the instance because every
        # message arrives in a separate handler call.
        self.previous_stacktrace = None
        # (signal, callback) pairs registered on self.device by attach_package().
        self._device_handlers = []
        # Whether attach_package() enabled spawn gating on self.device.
        self._spawn_gating = False

    def is_detached(self):
        """Return True once the main Frida session has reported a detach."""
        return self.detached

    def disconnect(self):
        """Unload injected scripts, release the device and detach the main session.

        Device signal handlers and spawn gating installed by attach_package()
        are removed too. Otherwise, once ``device``/``package_name`` are cleared,
        spawn_added() could no longer resume gated processes and they would stay
        suspended.
        """
        logger.info("Disconnected from frida server")

        for script in self.scripts:
            try:
                script.unload()
            except frida.InvalidOperationError as e:
                logger.error(e)

        self._remove_device_handlers()

        # The session is None when attach failed or disconnect() runs twice.
        if self.frida_session is not None:
            try:
                self.frida_session.detach()
            except frida.InvalidOperationError as e:
                logger.error(e)

        self.package_name = None
        self.device = None
        self.frida_session = None
        self.pid = -1
        self.scripts = []
        self.pending = []
        self.list_file_scripts = []
        self.previous_stacktrace = None

    def _remove_device_handlers(self):
        """Best-effort undo of the signal handlers and spawn gating set by attach_package()."""
        if self.device is not None:
            for signal, callback in self._device_handlers:
                try:
                    self.device.off(signal, callback)
                except Exception as e:  # teardown must not abort, e.g. on a lost device
                    logger.warning("Unable to remove '{}' handler: {}".format(signal, e))
            if self._spawn_gating:
                try:
                    self.device.disable_spawn_gating()
                except Exception as e:
                    logger.warning("Unable to disable spawn gating: {}".format(e))
        self._device_handlers = []
        self._spawn_gating = False

    def print_devices(self):
        """Log every device Frida can enumerate, prefixed with its position."""
        logger.info("List of devices")
        for position, device in enumerate(frida.enumerate_devices()):
            logger.info('{}) {}'.format(position, device))

    def connect_default_usb(self):
        """Select the USB device returned by ``frida.get_usb_device()``."""
        self.device = frida.get_usb_device()
        logger.info("Connected to device {}".format(self.device))

    def _read_scripts(self, scripts):
        """Return the contents of every file in *scripts*, each followed by a blank line.

        Files are read as UTF-8 so the result does not depend on the host locale.
        """
        chunks = []
        for script_file in scripts:
            with open(script_file, 'r', encoding='utf-8') as file:
                chunks.append(file.read())
            chunks.append('\n\n')
        return ''.join(chunks)

    def read_scripts(self, scripts):
        """Return ``ag_scripts`` followed by *scripts*, wrapped in a single ``Java.perform`` callback."""
        return "Java.perform(function () {\n" + self._read_scripts(self.ag_scripts + scripts) + "\n" + "});"

    def install_apk(self, filename):
        """Install *filename* on the selected device through the project's adb helper."""
        adb.adb(self.device.id, "install {}".format(filename))

    def attach_package(self, package_name, list_file_scripts, pid=None):
        """Attach to *package_name* (spawning it when *pid* is falsy) and inject the scripts.

        Spawn gating is enabled so that child processes of the package are caught
        by spawn_added() and instrumented too. A spawned process stays suspended
        until run_frida() resumes it.
        """
        self.package_name = package_name

        logger.info("Starting package {} {} {}".format(package_name, list_file_scripts, pid))

        self.list_file_scripts = list_file_scripts

        # Drop handlers from a previous attach; otherwise every spawn would be
        # handled (attached, injected, resumed) once per attach call.
        self._remove_device_handlers()
        # Create the event before any handler can fire.
        self.event = threading.Event()

        handlers = [
            ('spawn-added', self.spawn_added),
            ('spawn-removed', self.spawn_removed),
            ('child-added', self.child_added),
            ('child-removed', self.on_spawned),
            ('process-crashed', self.on_spawned),
            ('output', self.on_spawned),
            ('uninjected', self.on_spawned),
            # 'lost' carries no argument, so it cannot go to on_spawned(spawn).
            ('lost', self._on_lost),
        ]
        for signal, callback in handlers:
            self.device.on(signal, callback)
            self._device_handlers.append((signal, callback))

        self.device.enable_spawn_gating()
        self._spawn_gating = True
        logger.info('Enabled spawn gating')

        try:
            # It is not an existing process, spawn a new one
            if not pid:
                pid = self.device.spawn([package_name])

            self.pid = pid

            self.frida_session = self.device.attach(self.pid)
            # Subscribe before loading scripts so an early detach is not missed.
            self.frida_session.on('detached', self.on_detached)
            self.load_scripts(self.frida_session, list_file_scripts)
        except frida.NotSupportedError as e:
            logger.error(e)

    def load_scripts(self, current_session, scripts):
        """Create, hook and load one Frida script built from *scripts* in *current_session*.

        Failures are logged and swallowed so that a bad module does not abort the
        caller (e.g. spawn_added(), which must still resume the process).
        """
        try:
            logger.info('Loading scripts {}'.format(scripts))
            script = current_session.create_script(self.read_scripts(scripts))

            script.on("message", self.androguard_message_handler)
            script.load()
            self.scripts.append(script)
        except Exception as e:
            logger.error(e)

    def run_frida(self):
        """Resume the attached/spawned process."""
        logger.info("Running frida ! Resuming the PID {}".format(self.pid))
        self.device.resume(self.pid)

    def androguard_message_handler(self, message, payload):
        """Decode one Frida script message and record it.

        ``send`` messages carry a JSON string whose ``id`` selects the kind:

        * ``AG-EVENT``: a hooked call. Fields other than ``id``, ``ret``,
          ``timestamp`` and ``stacktrace`` become ``params``. The first two
          stack-trace entries are recorded as call/callee, and the stack trace
          is remembered for the AG-SYSTEM messages that follow.
        * ``AG-SYSTEM``: system information, attributed to the last AG-EVENT
          stack trace when there is one.
        * ``AG-BINDER``: only logged.

        ``error`` messages (exceptions thrown by the injected JavaScript) are
        logged. Other message types are ignored.
        """
        logger.debug("MESSAGE {} {}".format(message, payload))

        msg_type = message.get("type")
        if msg_type == "error":
            logger.error("Frida script error: {}".format(message.get("stack") or message.get("description")))
            return
        if msg_type != "send":
            return

        msg_payload = json.loads(message["payload"])
        msg_id = msg_payload["id"]

        if msg_id == "AG-EVENT":
            params = {key: value for key, value in msg_payload.items()
                      if key not in ("id", "ret", "timestamp", "stacktrace")}

            stacktrace = msg_payload["stacktrace"]
            function_call = stacktrace[0]
            function_callee = stacktrace[1]
            ret_value = json.dumps(msg_payload.get("ret"))

            logger.info("{} - [{}:{}] [{}] -> [{}]".format(msg_payload["timestamp"], function_call, function_callee, params, ret_value))
            # NOTE(review): MessageEvent's signature is (index, function_callee,
            # function_call, ...); this positional order is kept as-is to preserve
            # the data consumers already receive.
            self.message_queue.put(MessageEvent(self.idx, function_call, function_callee, params, ret_value))
            self.ag_session.insert_event(call=function_call, callee=function_callee, params=params, ret=ret_value)
            self.previous_stacktrace = stacktrace

            self.idx += 1
        elif msg_id == "AG-SYSTEM":
            params = {key: value for key, value in msg_payload.items()
                      if key not in ("id", "information", "timestamp", "stacktrace")}

            information = msg_payload["information"]
            previous = self.previous_stacktrace
            if previous:
                function_call, function_callee = previous[0], previous[1]
            else:
                function_call, function_callee = None, information

            logger.warning("{} - [{}:{}] [{}] -> [{}]".format(msg_payload["timestamp"], function_call, function_callee, information, params))
            self.message_queue.put(MessageSystem(self.idx, function_call, function_callee, params, information))
            self.ag_session.insert_system_event(call=function_call, callee=function_callee, params=params, information=information)

            # Only standalone system events (no preceding AG-EVENT) take a new index.
            if not previous:
                self.idx += 1
        elif msg_id == "AG-BINDER":
            logger.info("BINDER {} {}".format(message, payload))
            self.idx += 1

    @staticmethod
    def _fix_dex_header(bs):
        """Return *bs* unchanged if it starts with the DEX magic ``dex\\n``.

        Otherwise the first 8 bytes are replaced with ``dex\\n035\\0``.
        """
        # Must compare against bytes: a str literal never equals a bytes slice,
        # which previously forced the rewrite on every dump and clobbered valid
        # header versions.
        if bs[:4] != b"dex\n":
            bs = b"dex\n035\x00" + bs[8:]
        return bs

    def dump(self, package_name):
        """Dump the DEX images found in the target's memory to ``./<package_name>/<sha256>.dex``.

        Relies on the first loaded script exporting ``scandex`` and ``memorydump``
        RPC functions (presumably provided by the injected modules - not visible
        here). Each match is expected to be a mapping with ``addr`` and ``size``.
        Duplicates are skipped by MD5. Errors on a single match are logged and
        the dump continues.
        """
        if not self.scripts:
            logger.error("[DEXDump]: No script loaded, nothing to dump")
            return

        api = self.scripts[0].exports
        seen_digests = set()

        for info in api.scandex():
            try:
                bs = api.memorydump(info['addr'], info['size'])
                digest = md5(bs)
                if digest in seen_digests:
                    logger.warning("[DEXDump]: Skip duplicate dex {}<{}>".format(info['addr'], digest))
                    continue
                seen_digests.add(digest)

                os.makedirs(package_name, exist_ok=True)
                bs = self._fix_dex_header(bs)
                readable_hash = hashlib.sha256(bs).hexdigest()
                with open(os.path.join(package_name, readable_hash + ".dex"), 'wb') as out:
                    out.write(bs)
                # No extra kwargs here: loguru would run str.format() on the already
                # formatted text and fail on any braces it contains.
                logger.info("[DEXDump]: DexSize={}, SavePath={}/{}/{}.dex"
                            .format(hex(info['size']), os.getcwd(), package_name, readable_hash))
            except Exception as e:
                logger.error("[Except] - {}: {}".format(e, info))

    def on_detached(self, reason):
        """Record that the main Frida session was detached."""
        logger.info("Session is detached due to: {}".format(reason))
        self.detached = True

    def on_spawned(self, spawn):
        """Generic device-signal handler: remember the signal argument and wake waiters."""
        self.pending.append(spawn)
        self.event.set()

    def _on_lost(self):
        """Handle the argument-less device 'lost' signal."""
        logger.warning("Device lost")
        self.event.set()

    def spawn_added(self, spawn):
        """Handle a gated spawn.

        Scripts are injected into processes whose identifier starts with
        ``package_name``. Every gated process is resumed, even when
        attaching fails; otherwise it would stay suspended.
        """
        self.event.set()

        try:
            identifier = spawn.identifier or ""  # guard against a missing identifier
            if self.package_name and identifier.startswith(self.package_name):
                session = self.device.attach(spawn.pid)
                self.load_scripts(session, self.list_file_scripts)
        except Exception as e:  # signal-callback boundary: log, then still resume
            logger.error("Unable to instrument spawn {}: {}".format(spawn, e))
        finally:
            self.device.resume(spawn.pid)

    def spawn_removed(self, spawn):
        """Log a removed spawn and wake waiters."""
        logger.info('spawn_removed: {}'.format(spawn))
        self.event.set()

    def child_added(self, spawn):
        """Log a new child process."""
        logger.info('child_added: {}'.format(spawn))

    @staticmethod
    def _parse_pid(output):
        """Return the first numeric token of ``pidof`` output, or None if there is none."""
        for token in output.split():
            if token.isdigit():
                return int(token)
        return None

    def _find_pid(self, package_name):
        """Return the PID of *package_name* on the device (via ``adb shell pidof``), or None."""
        import subprocess  # local import: only needed for live tracing

        # An argument list avoids a local shell parsing package_name. Note that
        # adb still hands the joined command to the device's shell.
        command = ["adb", "-s", self.device.id, "shell", "pidof", package_name]
        try:
            result = subprocess.run(command, stdout=subprocess.PIPE, check=False)
        except OSError as e:
            logger.error("Unable to run adb: {}".format(e))
            return None
        return self._parse_pid(result.stdout.decode("utf-8", errors="replace"))

    def start_trace(self, filename, session, list_modules, live=False, noapk=False, dump=False):
        """Trace an app with the given Frida modules.

        Args:
            filename: APK path, or the package name when *live* is True.
            session: androguard session used to resolve the APK and store events.
            list_modules: script paths; entries containing ``*`` are globbed
                (recursively), keeping only files whose path lacks ``disable_``.
            live: attach to an already running package (its PID is looked up),
                instead of installing the APK and spawning it.
            noapk: accepted for interface compatibility; unused here.
            dump: dump DEX files instead of resuming the process.
        """
        self.ag_session = session

        logger.info("Start to trace {} {}".format(filename, list_modules))

        if not live:
            apk_obj, _dex_objs, _dx_obj = session.get_objects_apk(filename)
            if not apk_obj:
                logger.error("Can't find any APK object associated")
                return

        if not self.device:
            logger.error("Not connected to any device yet")
            return

        list_scripts_to_load = []
        for new_module in list_modules:
            if '*' not in new_module:
                list_scripts_to_load.append(new_module)
                continue
            list_scripts_to_load.extend(
                path for path in glob.iglob(new_module, recursive=True)
                if os.path.isfile(path) and "disable_" not in path
            )

        pid = None
        if not live:
            self.install_apk(filename)
            package_name = apk_obj.get_package()
        else:
            package_name = filename
            pid = self._find_pid(package_name)

        self.attach_package(package_name, list_scripts_to_load, pid)

        if not dump:
            self.run_frida()
        else:
            self.dump(package_name)